# Setup ----
# NOTE(review): rm(list = ls()) removes all (non-hidden) objects from the
# current environment, and setwd() below changes the session's working
# directory; both affect the whole R session, not just this script.
rm(list = ls())
library(tidyverse)
library(estimatr)
library(texreg)
library(ggrepel)
library(gridExtra)
library(ggridges)
library(ggplot2)
library(ggpubr)
library(grid)
library(stargazer)
# Point the working directory at the folder of the document currently open in
# RStudio so the relative paths used below resolve. The path comes from
# rstudioapi, so this line depends on running inside RStudio.
setwd(dirname(rstudioapi::getActiveDocumentContext()$path))
# Run fn.R in this session. Presumably it defines helpers used further down
# (e.g. format_num, add_parens) -- TODO confirm.
source('fn.R')

# Table A1: Sample size distribution ----
# One row per treatment arm, one column per batch, cells are row counts.
# A per-arm `Total` column and a final `Total` row are appended, and the
# result is printed as a LaTeX table by stargazer.
# NOTE(review): `data` must already exist at this point (data.rdata is only
# loaded further down) -- presumably it is created by fn.R; confirm.
x <- data %>%
  count(batch, arm, name = 'n') %>%
  # Arm/batch combinations without observations are reported as 0, not NA.
  pivot_wider(id_cols = arm, names_from = batch, values_from = n,
              values_fill = list(n = 0)) %>%
  mutate(
    # Sum across whichever batch columns exist rather than hard-coding
    # `1`..`10`, so a missing or additional batch does not break the table.
    Total = rowSums(across(-arm)),
    # Keep arm as a text label. The previous ifelse()-based NA fill ran over
    # every column, which turns a factor `arm` into its integer codes.
    arm = as.character(arm)
  ) %>%
  as.data.frame()

# Column totals: sum only the count columns and label the row explicitly,
# instead of summing the arm identifiers as well.
x <- bind_rows(
  x,
  x %>%
    summarize(across(where(is.numeric), sum)) %>%
    mutate(arm = 'Total')
)
stargazer(x, summary = FALSE, rownames = FALSE)

# Figures A1 and A2: Over-time posterior probability of being the best arm ----

# Rank arms by their value in the last period present in `post`.
#   post: long data frame with columns `period`, `name` and `value`.
# Returns a one-column tibble (`name`), highest final-period value first.
final_period_order <- function(post) {
  post %>%
    filter(period == max(period)) %>%
    arrange(-value) %>%
    dplyr::select(name)
}

# Draw one line per arm over periods, with arm labels placed to the right of
# the last period instead of a legend.
#   post:   long data frame with `period`, `value`, `name` and the ordered
#           factor `p.cat` (arms ranked by final-period value).
#   post_t: the rows of `post` belonging to the last period (label anchors).
#   y_max:  upper limit of the visible y-range.
# Returns the ggplot object.
plot_posterior_paths <- function(post, post_t, y_max) {
  arm_levels <- levels(post$p.cat)
  ggplot(data = post,
         aes(x = period, y = value, group = name,
             color = p.cat, linetype = p.cat)) +
    geom_line() +
    # The x-range extends past batch 10 and clipping is off so that the
    # repelled arm labels have room to the right of the lines.
    coord_cartesian(xlim = c(0.5, 10 + 4), ylim = c(0, y_max), clip = 'off') +
    scale_x_continuous(breaks = seq(1, 10, 1)) +
    scale_y_continuous(breaks = seq(0, 1, 0.1)) +
    # Top four arms get distinct colours; the remaining seven get the gray
    # shades gray0, gray12, ..., gray72 (one value per arm, 11 in total).
    scale_colour_manual(
      name = 'Position in last period',
      breaks = arm_levels,
      labels = arm_levels,
      values = c('red', 'blue', 'forestgreen', 'orange',
                 paste0('gray', seq(0, 72, 12)))) +
    scale_linetype_manual(
      name = '',
      values = c(rep('solid', 4), rep('longdash', 3),
                 rep('dotdash', 2), rep('twodash', 2))) +
    ylab('Posterior probability of being the best arm') +
    xlab('Batch number') +
    geom_text_repel(data = post_t, aes(label = p.cat), nudge_x = 10,
                    hjust = 1, segment.size = .2, seed = 343,
                    direction = 'y', size = 3) +
    theme_bw() +
    theme(legend.position = 'none',
          panel.grid.minor = element_blank(),
          plot.caption = element_text(hjust = 0))
}

# Figure A1: posterior path taken from `pmat` ----
load('data.rdata')

# Use the first 10 + 1 rows of `pmat`; the row index becomes `period`.
# NOTE(review): presumably row 1 is the starting (pre-data) state and rows
# 2-11 follow batches 1-10 -- confirm against how `pmat` is built.
post <- pmat[seq_len(10 + 1), ] %>%
  as.data.frame() %>%
  mutate(period = row_number()) %>%
  pivot_longer(cols = V1:V11) %>%
  mutate(name = gsub('V', 'Arm ', name))

# Order arms by the last period in the data (previously hard-coded as 11,
# while the label subset below already used max(period)).
fct_ord    <- final_period_order(post)
post$p.cat <- factor(post$name, levels = fct_ord$name)
post_t     <- post[post$period == max(post$period), ]

p_posterior <- plot_posterior_paths(post, post_t, y_max = 0.4)
ggsave(filename = 'Figures/Figure A1.pdf', plot = p_posterior,
       width = 5, height = 4)

# Figure A2: posterior path taken from the simulation output ----
load('Output/T2_simulation.rdata')
p <- out_s$d_ppmat

# Prepend a period-1 row whose arm columns are missing, then fill every
# missing numeric cell with 1/11 (equal probability across the 11 arms).
# NOTE(review): rows 31:40 presumably pick one simulated run's ten batches,
# with columns 1:12 being `period` plus X1..X11 -- confirm.
post <- bind_rows(data.frame(period = 1), p[31:40, 1:12]) %>%
  mutate(across(where(is.numeric), ~ifelse(is.na(.x), 1/11, .x))) %>%
  pivot_longer(cols = X1:X11) %>%
  mutate(name = gsub('X', 'Arm ', name))

fct_ord    <- final_period_order(post)
post$p.cat <- factor(post$name, levels = fct_ord$name)
post_t     <- post[post$period == max(post$period), ]

p_posterior <- plot_posterior_paths(post, post_t, y_max = 0.5)
ggsave(filename = 'Figures/Figure A2.pdf', plot = p_posterior,
       width = 5, height = 4)

# Figure A3: Arm-level estimates, main state vs. neighboring states ----
# No-intercept regressions of Y on arm, weighted by `ipw`, fit separately on
# the main-state subsample and on all other states, so each coefficient is
# the weighted mean of Y for that arm.
mod_or <- lm_robust(Y ~ arm - 1, data = data %>% filter(state == 'Main state'), weights = ipw)
mod_ot <- lm_robust(Y ~ arm - 1, data = data %>% filter(state != 'Main state'), weights = ipw)

gg_df <- bind_rows(`Main state` = tidy(mod_or),
                   `Neighboring states` = tidy(mod_ot),
                   .id = 'topic') %>%
  mutate(label = gsub(pattern = 'arm', replacement = 'Arm ', x = term),
         label = fct_reorder(factor(label), estimate)) %>%
  # Attach the number of rows per subsample and arm.
  # NOTE(review): this used to classify the main state with state == 'OR',
  # while the two models above filter on 'Main state'; the counts are now
  # aligned with the model filters -- confirm the coding of `state`.
  left_join(data %>%
              mutate(topic = ifelse(state == 'Main state',
                                    'Main state', 'Neighboring states')) %>%
              count(topic, arm, name = 'arm_n') %>%
              # Coefficient names from `Y ~ arm - 1` are 'arm' + level.
              mutate(term = paste0('arm', arm)),
            by = c('topic', 'term')) %>%
  # Text shown next to each estimate. One row (arm10, neighboring states)
  # carries the spelled-out version that serves as a key to the format.
  mutate(e_label = paste0(format_num(estimate, 3), " ", add_parens(std.error, 3)),
         t_label = ifelse(term == 'arm10' & topic != 'Main state',
                          paste0('Est: ', format_num(estimate, 3), ' (SE: ', format_num(std.error, 3), ') [N = ', arm_n, ']'),
                          paste0(e_label, ' [', arm_n, ']')))

p_est <- ggplot(gg_df, aes(x = estimate, y = label, xmin = conf.low, xmax = conf.high, color = topic)) +
  geom_point(position = position_dodge(width = 0.3)) +
  geom_errorbarh(position = position_dodge(width = 0.3), height = 0) +
  theme_bw() +
  xlab('Average proportion of respondents supporting the measure') +
  ylab('Treatment Arm') +
  scale_color_manual(values = c('darkcyan', 'deeppink4')) +
  # Labels use a wider dodge than the points so they sit above/below them.
  geom_text(aes(label = t_label),
            size = 2.5, show.legend = FALSE, position = position_dodge(width = 0.9)) +
  theme(strip.background = element_blank(),
        axis.title.y = element_blank(),
        legend.title = element_blank(),
        legend.position = 'bottom',
        plot.title = element_text(hjust = 0.5)) +
  xlim(0, 1)

ggsave(filename = "Figures/Figure A3.pdf", plot = p_est, width = 5, height = 6.5)

